Skip to content

[KV Offload] Fix multi-node KV offloading state desynchronization and JAX dispatch…#2983

Draft
amitkumar307d wants to merge 2 commits into
vllm-project:mainfrom
amitkumar307d:multinode-kv-offloading
Draft

[KV Offload] Fix multi-node KV offloading state desynchronization and JAX dispatch…#2983
amitkumar307d wants to merge 2 commits into
vllm-project:mainfrom
amitkumar307d:multinode-kv-offloading

Conversation

@amitkumar307d

@amitkumar307d amitkumar307d commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Description

This PR fixes state desynchronization, JAX dispatch deadlocks, and resource leaks in Multi-Node KV Offloading, ensuring stability during high-concurrency benchmarks.

Context & Problem:
Previously, running high-concurrency benchmarks (like prefix_repetition) on large models (e.g., Qwen3-Coder 480B) with multi-node KV offloading enabled led to fatal crashes (AssertionError) and runtime hangs.

  1. State Desynchronization: In multi-node setups, completion signals (finished chunk IDs) from worker nodes were being discarded in the stats aggregation phase, causing the Scheduler to miss confirmations and desync buffer accounting.
  2. JAX Dispatch Deadlocks: Running jax.device_put (D2H) in background threads caused non-deterministic dispatch orders across nodes, leading to low-level TPU runtime halts.
  3. Resource Leaks: Duplicate hashes in a batch resulted in multiple physical CPU chunks being allocated for the same content, wasting CPU RAM.
  4. Data Corruption Races (RAW/WAW):
    • Read-After-Write (RAW): A fast host could trigger a Load on a slow host before the slow host finished its Save in the background.
    • Write-After-Write (WAW): Rapid slot recycling caused background threads to race when writing to the same CPU memory address, potentially overwriting newer data with older data.
  5. Thread Safety Gaps: Concurrent modifications to the OrderedDict backing the LocalCPUBackend caused structural corruption under high load.

Solution

  • Robust Distributed State Synchronization:
    • Implemented KVOffloadConnectorStats.aggregate using collections.Counter and the union (|) operator to correctly merge completion signals across all worker nodes without artificially multiplying chunk counts.
    • Fixed KVOffloadConnectorStats.num_finished_blocks to accurately count total finished chunks instead of just request entries.
  • Deterministic JAX Dispatch Order: Moved jax.device_put (D2H dispatch) from background threads to the main thread in TPUOffloadConnectorWorker.start_save_kv (and batched version) to ensure globally aligned dispatch order across all TPU nodes.
  • Resource Leak Prevention: Added deduplication of chunk_hashes in OffloadManager.allocate_for_save to ensure only one CPUChunk is allocated per unique content hash in a batch.
  • Resilient Bookkeeping Cleanup: Relaxed strict assertions and replaced remove() with discard() in update_connector_output to handle redundant or late completion signals gracefully without crashing the engine.
  • Concurrency Hardening & Race Mitigation:
    • Local Read-Before-Write Sync: Worker main thread tracks _local_in_flight_saves and blocks (future.result()) during start_load_kv if a Load is requested for a chunk still being saved locally.
    • Dependency-Chained Serialization (WAW Fix): Background saves targeting the exact same CPU slot are chained via future dependencies in the thread pool, enforcing strict FIFO order of writes without requiring global central barriers.
    • Thread-Safe CPU Cache: Secured LocalCPUBackend by wrapping critical sections (add, get, reclaim) with threading.Lock to protect the underlying dictionary from concurrent corruption.

Tests

  • Verified with Qwen3-Coder 480B on a multi-node TPU cluster running prefix_repetition benchmarks under high concurrency.
  • Confirmed that the AssertionError state desync is resolved and the server runs to completion without JAX runtime hangs.

Qwen3-Coder 480B - Server command:

python3 -m vllm.entrypoints.openai.api_server \
--host=0.0.0.0 \
--port=8000 \
--tensor-parallel-size=16 \
--max-model-len=102400 \
--load-format=runai_streamer \
--kv-cache-dtype=fp8 \
--gpu-memory-utilization=0.8 \
--data-parallel-size=1 \
--max-num-batched-tokens=16384 \
--max-num-seqs=512 \
--model=Qwen/Qwen3-Coder-480B-A35B-Instruct \
--served-model-name=Qwen/Qwen3-Coder-480B-A35B-Instruct \
--enable-prefix-caching \
--async-scheduling \
--enable-expert-parallel \
--kv-transfer-config='{"kv_connector": "TPUOffloadConnector", "kv_connector_module_path": "tpu_inference.offload.tpu_offload_connector","kv_role": "kv_both", "kv_connector_extra_config": {"cpu_bytes_to_use": 107374182400, "lazy_offload": false}}'

Qwen3-Coder-480B - Client Command:

vllm bench serve   --backend=openai   --model=Qwen/Qwen3-Coder-480B-A35B-Instruct   --dataset-name=prefix_repetition   --host=localhost   --port=8000   --seed=123   --num-prompts=32   --max-concurrency=32   
--prefix-repetition-prefix-len=19424   --prefix-repetition-suffix-len=32   --prefix-repetition-output-len=1024   --prefix-repetition-num-prefixes=4   --percentile-metrics='ttft,tpot,itl,e2el'   --ignore-eos

Logs:

(APIServer pid=246100) INFO 06-24 08:44:22 [loggers.py:273] Engine 000: Avg prompt throughput: 3276.5 tokens/s, Avg generation throughput: 1095.4 tokens/s, Running: 32 reqs, Waiting: 0 reqs, GPU KV cache usage: 7.0%, Prefix cache hit rate: 85.2%, External prefix cache hit rate: 0.0%
(APIServer pid=246100) INFO 06-24 08:44:22 [metrics.py:103] KV Transfer metrics: Num finished save chunks =72, Num finished load chunks=0
(APIServer pid=246100) INFO 06-24 08:44:32 [loggers.py:273] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 1302.3 tokens/s, Running: 32 reqs, Waiting: 0 reqs, GPU KV cache usage: 8.1%, Prefix cache hit rate: 85.2%, External prefix cache hit rate: 0.0%
(APIServer pid=246100) INFO:     127.0.0.1:60682 - "GET /metrics HTTP/1.1" 200 OK
(APIServer pid=246100) INFO:     127.0.0.1:60688 - "GET /metrics HTTP/1.1" 200 OK

Results:

============ Serving Benchmark Result ============
Successful requests:                     256       
Failed requests:                         0         
Maximum request concurrency:             256       
Benchmark duration (s):                  476.23    
Total input tokens:                      25600009  
Total generated tokens:                  262144    
Request throughput (req/s):              0.54      
Output token throughput (tok/s):         550.45    
Peak output token throughput (tok/s):    768.00    
Peak concurrent requests:                256.00    
Total token throughput (tok/s):          54305.47  
---------------Time to First Token----------------
Mean TTFT (ms):                          40361.73  
Median TTFT (ms):                        40361.00  
P99 TTFT (ms):                           78277.68  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          396.16    
Median TPOT (ms):                        402.93    
P99 TPOT (ms):                           411.61    
---------------Inter-token Latency----------------
Mean ITL (ms):                           396.17    
Median ITL (ms):                         443.49    
P99 ITL (ms):                            489.86    
----------------End-to-end Latency----------------
Mean E2EL (ms):                          445629.55 
Median E2EL (ms):                        461372.58 
P99 E2EL (ms):                           475560.97 
==================================================

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@amitkumar307d amitkumar307d force-pushed the multinode-kv-offloading branch 6 times, most recently from 10e7c53 to 2c9a4f4 Compare June 26, 2026 18:24
@amitkumar307d amitkumar307d changed the title Fix multi-node KV offloading state desynchronization and JAX dispatch… [KV Offload] Fix multi-node KV offloading state desynchronization and JAX dispatch… Jun 26, 2026
… order

Signed-off-by: Amit Kumar <amitmkumar@google.com>
@amitkumar307d amitkumar307d force-pushed the multinode-kv-offloading branch from 3dd4d98 to a53270c Compare June 26, 2026 19:04
Signed-off-by: Amit Kumar <amitmkumar@google.com>
@amitkumar307d amitkumar307d force-pushed the multinode-kv-offloading branch from b165061 to ae9438f Compare June 26, 2026 19:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant